{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine tuning LLMs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 25000\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 25000\n",
       "    })\n",
       "    unsupervised: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 50000\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"imdb\")\n",
    "dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': \"Terrible movie. Nuff Said.<br /><br />These Lines are Just Filler. The movie was bad. Why I have to expand on that I don't know. This is already a waste of my time. I just wanted to warn others. Avoid this movie. The acting sucks and the writing is just moronic. Bad in every way. The only nice thing about the movie are Deniz Akkaya's breasts. Even that was ruined though by a terrible and unneeded rape scene. The movie is a poorly contrived and totally unbelievable piece of garbage.<br /><br />OK now I am just going to rag on IMDb for this stupid rule of 10 lines of text minimum. First I waste my time watching this offal. Then feeling compelled to warn others I create an account with IMDb only to discover that I have to write a friggen essay on the film just to express how bad I think it is. Totally unnecessary.\",\n",
       " 'label': 0}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset[\"train\"][100]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "28fc889e64364905b9d80db5a9198afb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/50000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n",
    "\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    return tokenizer(examples[\"text\"], padding = \"max_length\", truncation=True)\n",
    "\n",
    "\n",
    "tokenized_datasets = dataset.map(tokenize_function, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[101,\n",
       " 12008,\n",
       " 27788,\n",
       " 2523,\n",
       " 119,\n",
       " 151,\n",
       " 9435,\n",
       " 11455,\n",
       " 119,\n",
       " 133,\n",
       " 9304,\n",
       " 120,\n",
       " 135,\n",
       " 133,\n",
       " 9304,\n",
       " 120,\n",
       " 135,\n",
       " 1636,\n",
       " 12058,\n",
       " 1132,\n",
       " 2066,\n",
       " 17355,\n",
       " 9860,\n",
       " 119,\n",
       " 1109,\n",
       " 2523,\n",
       " 1108,\n",
       " 2213,\n",
       " 119,\n",
       " 2009,\n",
       " 146,\n",
       " 1138,\n",
       " 1106,\n",
       " 7380,\n",
       " 1113,\n",
       " 1115,\n",
       " 146,\n",
       " 1274,\n",
       " 112,\n",
       " 189,\n",
       " 1221,\n",
       " 119,\n",
       " 1188,\n",
       " 1110,\n",
       " 1640,\n",
       " 170,\n",
       " 5671,\n",
       " 1104,\n",
       " 1139,\n",
       " 1159,\n",
       " 119,\n",
       " 146,\n",
       " 1198,\n",
       " 1458,\n",
       " 1106,\n",
       " 11857,\n",
       " 1639,\n",
       " 119,\n",
       " 138,\n",
       " 6005,\n",
       " 2386,\n",
       " 1142,\n",
       " 2523,\n",
       " 119,\n",
       " 1109,\n",
       " 3176,\n",
       " 22797,\n",
       " 1105,\n",
       " 1103,\n",
       " 2269,\n",
       " 1110,\n",
       " 1198,\n",
       " 182,\n",
       " 14824,\n",
       " 7770,\n",
       " 119,\n",
       " 6304,\n",
       " 1107,\n",
       " 1451,\n",
       " 1236,\n",
       " 119,\n",
       " 1109,\n",
       " 1178,\n",
       " 3505,\n",
       " 1645,\n",
       " 1164,\n",
       " 1103,\n",
       " 2523,\n",
       " 1132,\n",
       " 14760,\n",
       " 9368,\n",
       " 138,\n",
       " 19610,\n",
       " 2315,\n",
       " 112,\n",
       " 188,\n",
       " 13016,\n",
       " 119,\n",
       " 2431,\n",
       " 1115,\n",
       " 1108,\n",
       " 9832,\n",
       " 1463,\n",
       " 1118,\n",
       " 170,\n",
       " 6434,\n",
       " 1105,\n",
       " 8362,\n",
       " 23063,\n",
       " 4902,\n",
       " 9372,\n",
       " 2741,\n",
       " 119,\n",
       " 1109,\n",
       " 2523,\n",
       " 1110,\n",
       " 170,\n",
       " 9874,\n",
       " 14255,\n",
       " 19091,\n",
       " 5790,\n",
       " 1105,\n",
       " 5733,\n",
       " 8362,\n",
       " 26438,\n",
       " 2727,\n",
       " 1104,\n",
       " 14946,\n",
       " 119,\n",
       " 133,\n",
       " 9304,\n",
       " 120,\n",
       " 135,\n",
       " 133,\n",
       " 9304,\n",
       " 120,\n",
       " 135,\n",
       " 10899,\n",
       " 1208,\n",
       " 146,\n",
       " 1821,\n",
       " 1198,\n",
       " 1280,\n",
       " 1106,\n",
       " 26133,\n",
       " 1113,\n",
       " 146,\n",
       " 18219,\n",
       " 1830,\n",
       " 1111,\n",
       " 1142,\n",
       " 4736,\n",
       " 3013,\n",
       " 1104,\n",
       " 1275,\n",
       " 2442,\n",
       " 1104,\n",
       " 3087,\n",
       " 5867,\n",
       " 119,\n",
       " 1752,\n",
       " 146,\n",
       " 5671,\n",
       " 1139,\n",
       " 1159,\n",
       " 2903,\n",
       " 1142,\n",
       " 1228,\n",
       " 1348,\n",
       " 119,\n",
       " 1599,\n",
       " 2296,\n",
       " 15957,\n",
       " 1106,\n",
       " 11857,\n",
       " 1639,\n",
       " 146,\n",
       " 2561,\n",
       " 1126,\n",
       " 3300,\n",
       " 1114,\n",
       " 146,\n",
       " 18219,\n",
       " 1830,\n",
       " 1178,\n",
       " 1106,\n",
       " 7290,\n",
       " 1115,\n",
       " 146,\n",
       " 1138,\n",
       " 1106,\n",
       " 3593,\n",
       " 170,\n",
       " 175,\n",
       " 17305,\n",
       " 4915,\n",
       " 10400,\n",
       " 1113,\n",
       " 1103,\n",
       " 1273,\n",
       " 1198,\n",
       " 1106,\n",
       " 6848,\n",
       " 1293,\n",
       " 2213,\n",
       " 146,\n",
       " 1341,\n",
       " 1122,\n",
       " 1110,\n",
       " 119,\n",
       " 8653,\n",
       " 1193,\n",
       " 14924,\n",
       " 119,\n",
       " 102,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenized_datasets['train'][100]['input_ids']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "small_train_dataset = tokenized_datasets[\"train\"].shuffle(seed=42).select(range(200))\n",
    "small_eval_dataset = tokenized_datasets[\"test\"].shuffle(seed=42).select(range(200))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", num_labels=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import evaluate\n",
    "\n",
    "metric = evaluate.load(\"accuracy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(eval_pred):\n",
    "    logits, labels = eval_pred\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "    return metric.compute(predictions=predictions, references=labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "training_args = TrainingArguments(output_dir=\"test_trainer\", num_train_epochs = 2, evaluation_strategy=\"epoch\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=small_train_dataset,\n",
    "    eval_dataset=small_eval_dataset,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "716f34ddc83b43a09001eeff8f92e95d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f19a8c53696244b98ce054fd2bdc65ce",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.6720085144042969, 'eval_accuracy': 0.58, 'eval_runtime': 609.7916, 'eval_samples_per_second': 0.328, 'eval_steps_per_second': 0.041, 'epoch': 1.0}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "48988d14f0854b4eae7cb398a3b697fa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.5366445183753967, 'eval_accuracy': 0.82, 'eval_runtime': 524.186, 'eval_samples_per_second': 0.382, 'eval_steps_per_second': 0.048, 'epoch': 2.0}\n",
      "{'train_runtime': 3886.3326, 'train_samples_per_second': 0.103, 'train_steps_per_second': 0.013, 'train_loss': 0.6406719207763671, 'epoch': 2.0}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=50, training_loss=0.6406719207763671, metrics={'train_runtime': 3886.3326, 'train_samples_per_second': 0.103, 'train_steps_per_second': 0.013, 'train_loss': 0.6406719207763671, 'epoch': 2.0})"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.save_model('models/sentiment-classifier')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7aa8ec9515564d9e86d90f4aa9123457",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6cfa7ed476914083907c2a5af6a9bbad",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "be504167c51a4a26883618073a425867",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "pytorch_model.bin:   0%|          | 0.00/433M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b27c5f62868240dca4e9f50c34b9f78f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "training_args.bin:   0%|          | 0.00/4.03k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'https://huggingface.co/vaalto/test_trainer/tree/main/'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.push_to_hub(\"valen/sentiment-classifier\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import BertConfig, BertModel\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained('models/sentiment-classifier')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = tokenizer(\"I cannot stand it anymore!\", return_tensors=\"pt\")\n",
    "\n",
    "outputs = model(**inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SequenceClassifierOutput(loss=None, logits=tensor([[ 0.6467, -0.0041]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor([[0.6571879  0.34281212]], shape=(1, 2), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "predictions = tf.math.softmax(outputs.logits.detach(), axis=-1)\n",
    "print(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'label': 'LABEL_0', 'score': 0.6545460224151611}]"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "classifier = pipeline(task ='sentiment-analysis', model = model, tokenizer = tokenizer)\n",
    "classifier('I cannot believe you did it again!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## AutoTrain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "#reducing the rows of the dataset to leverage the free autotrain\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "small_dataset = dataset[\"train\"].shuffle(seed=42).select(range(3000))\n",
    "small_dataset = small_dataset.to_pandas()\n",
    "\n",
    "small_dataset['label'] = small_dataset['label'].replace({0: 'Negative', 1: 'Positive'})\n",
    "\n",
    "small_dataset.head()\n",
    "\n",
    "small_dataset.to_csv('imdb_small.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b1c4b7d4ba654a20ac0e0fc5ad8accbc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)okenizer_config.json:   0%|          | 0.00/314 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "def6613540434451b1987e769f483094",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b040a14a452a416a9b02bde41858cde0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading tokenizer.json:   0%|          | 0.00/712k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "98243875b6c14c578feba62de25e1187",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)cial_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "\n",
    "\n",
    "import os\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "load_dotenv()\n",
    "\n",
    "#os.environ[\"HUGGINGFACEHUB_API_TOKEN\"]\n",
    "huggingfacehub_api_token = os.environ['HUGGINGFACEHUB_API_TOKEN']\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\"vaalto/autotrain-imdb-sentiment-analysis-89579143987\", token = huggingfacehub_api_token)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"vaalto/autotrain-imdb-sentiment-analysis-89579143987\", token = huggingfacehub_api_token)\n",
    "\n",
    "inputs = tokenizer(\"This movie put me in a very nostalgic mood\", return_tensors=\"pt\")\n",
    "\n",
    "outputs = model(**inputs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SequenceClassifierOutput(loss=None, logits=tensor([[-0.8675,  1.3277]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor([[0.10018114 0.89981884]], shape=(1, 2), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "predictions = tf.math.softmax(outputs.logits.detach(), axis=-1)\n",
    "print(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: 'Negative', 1: 'Positive'}"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.config.id2label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'label': 'Negative', 'score': 0.6326483488082886}]"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "classifier = pipeline(task ='sentiment-analysis', model = model, tokenizer = tokenizer)\n",
    "classifier('I cannot believe you did it again!')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
